{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 计算分配内存"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import tir\n",
    "from tvm.script import tir as T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.script.ir_module\n",
    "class Module:\n",
    "    @T.prim_func\n",
    "    def scale_by_two(a: T.Buffer((128,), \"int8\"), c: T.Buffer((128,), \"int8\")):\n",
    "        for i in T.serial(128):\n",
    "            with T.block(\"C\"):\n",
    "                c[i] = a[i] * T.int8(2)\n",
    "\n",
    "\n",
    "    @T.prim_func\n",
    "    def scale_by_two_three(a: T.Buffer((128,), \"int8\"), c: T.Buffer((128,), \"int8\")):\n",
    "        B = T.alloc_buffer([128], dtype=\"int8\", scope=\"global.vtcm\")\n",
    "        for i in T.serial(128):\n",
    "            with T.block(\"B\"):\n",
    "                B[i] = a[i] * T.int8(2)\n",
    "        for i in T.serial(128):\n",
    "            with T.block(\"C\"):\n",
    "                c[i] = B[i] * T.int8(3)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 计算每个 scope 内存分配"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "for primFunc, size in [(Module[\"scale_by_two\"], 128), (Module[\"scale_by_two_three\"], 256)]:\n",
    "    mod = tvm.IRModule.from_expr(primFunc.with_attr(\"global_symbol\", \"main\"))\n",
    "    sch = tir.Schedule(mod, debug_mask=\"all\")\n",
    "    block_c = sch.get_block(\"C\")\n",
    "    (flat,) = sch.get_loops(block_c)\n",
    "    cache_block = sch.cache_read(block_c, 0, storage_scope=\"global.vtcm\")\n",
    "    sch.compute_at(cache_block, flat)\n",
    "\n",
    "    mod = sch.mod\n",
    "    mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)\n",
    "    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)\n",
    "    sizes = tvm.tir.analysis.calculate_allocated_bytes(mod[\"main\"])\n",
    "    assert \"main\" in sizes, 'Calls with PrimFunc is expected to return with function key as \"main\"'\n",
    "    sizes = sizes[\"main\"]\n",
    "    assert sizes.get(\"global.vtcm\", 0) == size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 计算混合 scope 的内存分配\n",
    "\n",
    "test_scale_by"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "@T.prim_func\n",
    "def matmul_mix_scope(a: T.handle, b: T.handle, c: T.handle) -> None:\n",
    "    A = T.match_buffer(a, [128, 128], scope=\"global\")\n",
    "    B = T.match_buffer(b, [128, 128], scope=\"global\")\n",
    "    C = T.match_buffer(c, [128, 128], scope=\"global\")\n",
    "    A_allocated = T.alloc_buffer([128, 128], dtype=\"float32\", scope=\"global.texture\")\n",
    "    B_allocated = T.alloc_buffer([128, 128], dtype=\"float32\", scope=\"global.texture\")\n",
    "    C_allocated = T.alloc_buffer([128, 128], dtype=\"float32\", scope=\"global\")\n",
    "\n",
    "    for i, j in T.grid(128, 128):\n",
    "        with T.block(\"A.allocated\"):\n",
    "            A_allocated[i, j] = A[i, j]\n",
    "    for i, j in T.grid(128, 128):\n",
    "        with T.block(\"B.allocated\"):\n",
    "            B_allocated[i, j] = B[i, j]\n",
    "\n",
    "    for i, j, k in T.grid(128, 128, 128):\n",
    "        with T.block(\"update\"):\n",
    "            vi, vj, vk = T.axis.remap(\"SSR\", [i, j, k])\n",
    "            with T.init():\n",
    "                C_allocated[vi, vj] = 0.0\n",
    "            C_allocated[vi, vj] = C[vi, vj] + A_allocated[vi, vk] * B_allocated[vj, vk]\n",
    "\n",
    "    for i, j in T.grid(128, 128):\n",
    "        with T.block(\"C\"):\n",
    "            C[i, j] = C_allocated[i, j]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "groups = [(\"global\", 65536), (\"global.texture\", 131072), (\"global.texture-nhwc\", 0)]\n",
    "\n",
    "for scope, size in groups:\n",
    "    mod = tvm.IRModule({\"main\": matmul_mix_scope})\n",
    "    mod = tvm.tir.transform.LowerInitBlock()(mod)\n",
    "    mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)\n",
    "    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)\n",
    "    sizes = tvm.tir.analysis.calculate_allocated_bytes(mod[\"main\"])\n",
    "    assert \"main\" in sizes, 'Calls with PrimFunc is expected to return with function key as \"main\"'\n",
    "    sizes = sizes[\"main\"]\n",
    "    assert sizes.get(scope, 0) == size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## full_mod_calculator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_schedule(sch, func_name):\n",
    "    sch.work_on(func_name)\n",
    "    block_c = sch.get_block(\"C\")\n",
    "    sch.cache_read(block_c, 0, storage_scope=\"global.vtcm\")\n",
    "\n",
    "sch = tvm.tir.Schedule(Module, debug_mask=\"all\")\n",
    "apply_schedule(sch, \"scale_by_two\")\n",
    "apply_schedule(sch, \"scale_by_two_three\")\n",
    "mod = tvm.tir.transform.ConvertBlocksToOpaque()(sch.mod)\n",
    "mod = tvm.tir.transform.LowerOpaqueBlock()(mod)\n",
    "sizes = tvm.tir.analysis.calculate_allocated_bytes(mod)\n",
    "assert \"scale_by_two\" in sizes, \"Values for scale_by_two not found\"\n",
    "scale_by_two_sizes = sizes[\"scale_by_two\"]\n",
    "assert (\n",
    "    \"global.vtcm\" in scale_by_two_sizes\n",
    "), \"Expected global.vtcm allocation to be calculated scale_by_two\"\n",
    "assert scale_by_two_sizes[\"global.vtcm\"] == 128, \"Expected the calculated size to be 128\"\n",
    "scale_by_two_three_sizes = sizes[\"scale_by_two_three\"]\n",
    "assert (\n",
    "    \"global.vtcm\" in scale_by_two_three_sizes\n",
    "), \"Expected global.vtcm allocation to be calculated scale_by_two_three\"\n",
    "assert scale_by_two_three_sizes[\"global.vtcm\"] == 256, \"Expected the calculated size to be 256\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
